# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" Produce the dataset for cifar10 """
import os
from urllib.error import URLError

from mindspore.common import dtype as mstype
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2

from mindvision.common.dataset.dataloader import DataLoader
from mindvision.classification.dataset.meta_dataset import Meta

class Cifar10DataLoader(DataLoader, metaclass=Meta):
    """CIFAR-10 dataset loader.

    Optionally downloads the CIFAR-10 archive, then builds a MindSpore data
    pipeline: image augmentation/normalization on the ``image`` column, a type
    cast on the ``label`` column, followed by batching and repeating.
    """

    # Mirror base URLs tried in order until one download succeeds.
    mirrors = [
        'http://www.cs.toronto.edu/~kriz/'
    ]

    # (archive filename, expected md5 checksum) pairs to fetch.
    resources = [
        ("cifar-10-python.tar.gz", "c58f30108f718f92721af3b95e74349a")
    ]

    # Numeric class id -> human-readable label.
    index2label = {0: 'plane', 1: 'car', 2: 'bird', 3: 'cat', 4: 'deer',
                   5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}

    def __init__(self, dataset=None, train=True, transform=None, target_transform=None, batch_size=32, repeat_num=1,
                 num_parallel_workers=None, download=False):
        """Initialize the loader and, if requested, download the data.

        Args:
            dataset: Underlying dataset object; may be None (presumably
                created later by the base class — confirm against DataLoader).
            train (bool): Whether this loader serves the training split.
            transform: Image transform(s); falls back to
                ``_default_transform()`` when None.
            target_transform: Label transform; falls back to a ``TypeCast``
                to int32 when None.
            batch_size (int): Batch size used by ``_pipeline``.
            repeat_num (int): Number of dataset repeats used by ``_pipeline``.
            num_parallel_workers: Worker count forwarded to ``dataset.map``.
            download (bool): If True, fetch the archive before use.
        """
        # Python-3 zero-argument super(); behavior identical to the explicit
        # two-argument form.
        super().__init__(dataset=dataset, train=train, transform=transform,
                         target_transform=target_transform, batch_size=batch_size,
                         repeat_num=repeat_num, num_parallel_workers=num_parallel_workers,
                         download=download)

        if download:
            self._download()

    def _download(self):
        """Download the cifar10 data if it doesn't exist already.

        Tries each mirror in order and stops at the first success; raises
        RuntimeError when every mirror fails for a file.
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    self._download_url(
                        url, root=self.raw_folder,
                        filename=filename,
                        md5=md5
                    )
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    # Blank line between attempts, printed on success and failure.
                    print()
                # Download succeeded: skip the for/else and the next mirrors.
                break
            else:
                # Loop exhausted without a break: every mirror failed.
                raise RuntimeError("Error downloading {}".format(filename))

    def _default_transform(self):
        """Return the default list of image map operations.

        Random crop + horizontal flip augmentation, resize to 224x224,
        rescale to [0, 1], per-channel normalization, then HWC -> CHW layout.
        """

        # define map operations
        trans = [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Resize((224, 224)),
            C.Rescale(1.0 / 255.0, 0.0),
            C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            C.HWC2CHW()
        ]
        return trans

    def _transforms(self):
        """Map image and label transforms onto the dataset in place."""
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a None dataset slip through silently.
        if not self.dataset:
            raise RuntimeError("dataset is None")
        trans = self.transform if self.transform else self._default_transform()
        self.dataset = self.dataset.map(operations=trans, input_columns="image",
                                        num_parallel_workers=self.num_parallel_workers)
        # Labels arrive as a non-int32 type; cast unless the caller supplied
        # their own target transform.
        type_cast_op = self.target_transform if self.target_transform else C2.TypeCast(mstype.int32)
        self.dataset = self.dataset.map(operations=type_cast_op,
                                        input_columns="label",
                                        num_parallel_workers=self.num_parallel_workers)

    def _pipeline(self):
        """Run transforms, then batch (dropping the remainder) and repeat."""
        self._transforms()
        self.dataset = self.dataset.batch(self.batch_size, drop_remainder=True)
        self.dataset = self.dataset.repeat(self.repeat_num)
